from functools import partial

import numpy as np
import jax
import jax.numpy as jnp

from diffgro.diffgro.functions import _loss_llm, _loss_txt


def _manual_loss_fn(context_type, context_target):
    """Build and JIT-compile a hand-written guidance loss for ``context_type``.

    The loss is generated as Python source, prefixed with
    ``@partial(jax.jit, static_argnums=(1,))`` (so ``obs_dim`` is a static
    argument), and ``exec``-ed into this module's globals.  As a side effect
    this (re)binds the module-level name ``_loss_fn``.

    Args:
        context_type: One of ``'speed below'``, ``'speed above'``,
            ``'x-axis faster'``, ``'y-axis faster'``,
            ``'speed above and below'``, ``'faster'`` or ``'slower'``.
        context_target: Threshold interpolated verbatim (via ``str``) into the
            generated source.  For ``'speed above and below'`` it must be
            indexable as ``(lower, upper)``.  It is ignored by ``'faster'``
            and ``'slower'``.

    Returns:
        A ``(loss_fn, code)`` tuple: the jitted ``_loss_fn(x, obs_dim)`` and
        the exact source string that was executed.

    Raises:
        NotImplementedError: If ``context_type`` is not one of the above.
    """
    # NOTE(review): every generated loss slices x[:,:,:3] and names it `act`;
    # presumably x is (batch, horizon, dim) with the first three features being
    # the action components -- confirm against callers.
    if context_type == 'speed below':
        # Hinge penalty on per-step speed exceeding the target.
        code = f"def _loss_fn(x, obs_dim):\n\tact = x[:,:,:3]\n\tspeed = jnp.linalg.norm(act, axis=-1)\n\tloss = jnp.mean(jnp.maximum(0.0, speed - {context_target}))\n\treturn loss"
    elif context_type == 'speed above':
        # Hinge penalty on per-step speed falling short of the target.
        code = f"def _loss_fn(x, obs_dim):\n\tact = x[:,:,:3]\n\tspeed = jnp.linalg.norm(act, axis=-1)\n\tloss = jnp.mean(jnp.maximum(0.0, {context_target} - speed))\n\treturn loss"
    elif context_type == 'x-axis faster':
        # Norm of the single x component, i.e. |act_x|, must reach the target.
        code = f"def _loss_fn(x, obs_dim):\n\tact = x[:,:,:3]\n\tspeed = jnp.linalg.norm(act[:,:,0:1], axis=-1)\n\tloss = jnp.mean(jnp.maximum(0.0, {context_target} - speed))\n\treturn loss"
    elif context_type == 'y-axis faster':
        # Norm of the single y component, i.e. |act_y|, must reach the target.
        code = f"def _loss_fn(x, obs_dim):\n\tact = x[:,:,:3]\n\tspeed = jnp.linalg.norm(act[:,:,1:2], axis=-1)\n\tloss = jnp.mean(jnp.maximum(0.0, {context_target} - speed))\n\treturn loss"
    elif context_type == 'speed above and below':
        # context_target = (lower, upper): penalize speed above upper and
        # below lower.
        code = f"def _loss_fn(x, obs_dim):\n\tact = x[:,:,:3]\n\tspeed = jnp.linalg.norm(act, axis=-1)\n\tloss = jnp.mean(jnp.maximum(0.0, speed - {context_target[1]})) +  jnp.mean(jnp.maximum(0.0, {context_target[0]} - speed))\n\treturn loss" # below + above
    elif context_type == 'faster':
        # Minimizing negative mean speed pushes speed up; no target used.
        code = "def _loss_fn(x, obs_dim):\n\tact = x[:,:,:3]\n\tspeed = jnp.linalg.norm(act, axis=-1)\n\tloss=-jnp.mean(speed)\n\treturn loss"
    elif context_type == 'slower':
        # Minimizing mean speed pushes speed down; no target used.
        code = "def _loss_fn(x, obs_dim):\n\tact = x[:,:,:3]\n\tspeed = jnp.linalg.norm(act, axis=-1)\n\tloss=jnp.mean(speed)\n\treturn loss"
    else:
        raise NotImplementedError(
            f"Unsupported manual guidance context_type: {context_type!r}"
        )

    code = '@partial(jax.jit, static_argnums=(1,))\n' + code
    # Executing in the module globals lets the generated source resolve
    # `partial`, `jax` and `jnp`; it also defines the module-level `_loss_fn`.
    # NOTE(review): context_target is spliced into source and exec'd -- never
    # pass untrusted strings here (code injection).
    exec(code, globals())
    return globals()["_loss_fn"], code


# Registry of guidance-loss factories keyed by guide mode name.
# "blank" maps to None (no factory registered for that mode); "manual" uses
# the hand-written templates above; "llm" and "txt" come from
# diffgro.diffgro.functions.
guide_fn_dict = dict(
    blank=None,
    manual=_manual_loss_fn,
    llm=_loss_llm,
    txt=_loss_txt,
)
